This challenge is about a simple server that sends out RSA-encrypted copies of a flag. Here we are going to use a classic exploit, Hastad's broadcast attack, but with a small twist: optimising the code so that it actually finishes :)
We have a hidden flag value M, which is encrypted into C by:
$$C \equiv M^{65537} \pmod{N}$$
where N = pq, p and q are two big (512-bit) primes, and N > M. This makes N a 1024-bit number. The server generates C and N and sends them to us, the client, as many times as we like. Each time we get a different C and a different N > M built from entirely different p and q.
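Notice that M never changes, nor is it randomly padded. So we can use Hastad's attack: collect 65537 different ciphertexts ($C_1, C_2, \dots$) and 65537 different moduli ($N_1, N_2, \dots$), then recover $M^{65537}$ with the Chinese Remainder Theorem:
$$M^{65537} \bmod \mathrm{lcm}(N_1, N_2, \dots, N_{65537}) = \mathrm{CRT}\big([C_1, C_2, \dots, C_{65537}],\ [N_1, N_2, \dots, N_{65537}]\big)$$
If each $N_i$ is composed of different prime numbers, then $\mathrm{lcm}(N_1, N_2, \dots) = N_1 N_2 \cdots N_{65537}$, so
$$M^{65537} \bmod (N_1 N_2 \cdots N_{65537}) = \mathrm{CRT}\big([C_1, C_2, \dots, C_{65537}],\ [N_1, N_2, \dots, N_{65537}]\big)$$
Since $M < N_i$ for every $i$, we have $M^{65537} < N_1 N_2 \cdots N_{65537}$, so the modular reduction changes nothing and the CRT result is exactly $M^{65537}$.
To recover M, all we need is to take the 65537-th root of $M^{65537}$, which can be done very fast using Sagemath, and we're done!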
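To make the idea concrete, here is a minimal pure-Python sketch of the same attack at toy scale. The exponent e = 3, the three tiny moduli, the message 42, and the helper names `crt_pair` and `integer_root` are all assumptions chosen for illustration; they are not the challenge's values or Sage's API.

```python
def crt_pair(r1, m1, r2, m2):
    """Combine x = r1 (mod m1) and x = r2 (mod m2), assuming gcd(m1, m2) = 1."""
    # pow(m1, -1, m2) is the modular inverse of m1 modulo m2 (Python 3.8+).
    t = ((r2 - r1) * pow(m1, -1, m2)) % m2
    return r1 + m1 * t, m1 * m2


def integer_root(x, k):
    """Largest integer r with r**k <= x, found by binary search."""
    lo, hi = 0, 1 << (x.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** k <= x:
            lo = mid
        else:
            hi = mid - 1
    return lo


# Toy parameters (illustrative only): e = 3 needs 3 ciphertexts.
e = 3
moduli = [11 * 17, 13 * 19, 23 * 29]  # pairwise coprime, each larger than m
m = 42                                # same unpadded message every time
ciphertexts = [pow(m, e, n) for n in moduli]

# Fold the congruences together; x ends up equal to m**e as a plain integer.
x, modulus = ciphertexts[0], moduli[0]
for c, n in zip(ciphertexts[1:], moduli[1:]):
    x, modulus = crt_pair(x, modulus, c, n)

recovered = integer_root(x, e)
print(recovered)
```

The real challenge works the same way, only with e = 65537 congruences and 1024-bit moduli, which is exactly where the cost of the CRT step starts to matter.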
Collecting 65537 Cs and Ns is doable. Since the server sends us 100 C-N pairs per request, we just need to send (65537//100 + 1) = 656 requests in total. It took a while (about an hour), but it is definitely feasible.
from pwn import remote
from tqdm import trange

##### socket #####
host = '34.136.227.174'
port = 25455

def recving():
    """Open one connection and read the 100 lines the server sends."""
    ns = []
    cs = []
    io = remote(host, port)
    for _ in range(100):
        # Split each line on b':::'; the first field is N, the last is C
        data = io.recvline().split(b':::')
        ns.append(int(data[0]))
        cs.append(int(data[-1]))
    io.close()
    return cs, ns

##### main #####
# Get results
cs, ns = [], []
for _ in trange(65537 // 100 + 1):
    result = recving()
    cs += result[0]
    ns += result[1]

# Save under the names c and n, which the solver scripts import
with open('c.py', 'w') as f:
    f.write(f'c = {cs}')
with open('n.py', 'w') as f:
    f.write(f'n = {ns}')
Now we have our values c and n in two files, c.py and n.py.
From there, we just need to solve it with:
from sage.all import *
from Crypto.Util.number import *
from c import c
from n import n
m_65537 = crt(c, n)
print(long_to_bytes(Integer(m_65537).nth_root(65537)))
So we waited...
One hour... Two hours... Three hours... (actually I gave up after 10 minutes, this is just for dramatic purposes, but my friend Timmy kept it running for another 2 hours :V he can confirm)
And it did not give us any output.
Why did it take so long???
Here's a snippet from Sage's implementation of crt(): the function CRT_list(), which handles the case where the inputs to crt() are lists.
### Snippet from CRT_list()
x = v[0]
m = moduli[0]
from sage.arith.functions import lcm
for i in range(1, len(v)):
    x = CRT(x,v[i],m,moduli[i])
    m = lcm(m,moduli[i])
return x % m
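Basically, what this function does is build up the calculation like so:
$$x_0 = C_1$$
$$x_i = \mathrm{CRT}\big(x_{i-1},\ C_{i+1};\ N_1 N_2 \cdots N_i,\ N_{i+1}\big), \quad i = 1, 2, \dots$$
The result of the algorithm is the last of these values, $x_{\mathrm{len}(C)-1}$.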
The algorithm works fine and gets the job done. The problem appears when the number of values passed to CRT is big.
To demonstrate what I mean, let's dig in to see what happens in each loop of that CRT_list() function.
$x_0$ has 1024 bits. Let's call this number of bits $b$, and assume that the time taken to write data to a $b$-bit number is $T_b$, which scales linearly with $b$ (so $T_{jb} = j\,T_b$).
After each loop, we need to create a new variable $x_i$ with $b$ more bits than $x_{i-1}$ and write data to it. So over time, the number of bits of $x_i$ looks like this:
$$|x_0| = b,\quad |x_1| = 2b,\quad \dots,\quad |x_i| = (i+1)\,b,\quad \dots,\quad |x_{65536}| = 65537\,b$$
The total time taken to create the variables and write data is:
$$T = T_b + T_{2b} + \dots + T_{65537b} = (1 + 2 + \dots + 65537)\,T_b = \frac{65537 \cdot 65538}{2}\,T_b \approx 2.1 \times 10^{9}\,T_b$$
As we can see, for b = 1024 the time taken to write one number is negligible, but 2 billion such writes make a HUGE difference.
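As a quick sanity check on the "2 billion" figure, the count can be reproduced directly under the linear cost model above, where a write to a j·b-bit number counts as j units of $T_b$. The variable names are only illustrative.

```python
k = 65537   # number of (C, N) pairs fed to CRT
b = 1024    # bits per modulus

# Step i writes a number of roughly i*b bits, i.e. i units of T_b.
sequential_units = sum(range(1, k + 1))

# Closed form of the same sum: k(k+1)/2.
closed_form = k * (k + 1) // 2

# Size of the final accumulated value, in bits.
final_bits = k * b

print(sequential_units, closed_form, final_bits)
```

Both counts agree, and the final value is about 67 million bits long, which is why every late iteration of the sequential loop is so expensive.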
Let's analyse the runtime in O-notation. Replacing 65537, the number of values used in CRT, with $k$, we have:
$$T = \frac{k(k+1)}{2}\,T_b = O(b\,k^2)$$
Not a very good runtime to be honest ✌️
And I haven't mentioned the runtime of multiplying two numbers, which is the more dominant cost. With Karatsuba multiplication it is $O(b^{\log_2 3})$ with respect to the number of bits, and it is used a lot during crt. Here I will stick to the write-data cost as it is easier to elaborate 😗
Instead of doing CRT sequentially from left to right like Sage's algorithm, in each loop we do CRT on each pair of adjacent numbers in the array. After k/2 runs we get 2 new lists of k/2 2b-bit remainders and moduli. Then we do the same thing again, yielding 2 new lists of k/4 4b-bit remainders and moduli... until the list of remainders has only one element 😗
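The pairwise scheme just described can be sketched in pure Python as follows. This is only an illustration with small moduli and made-up helper names (`crt_pair`, `crt_tree`); it assumes the moduli are pairwise coprime and the input is non-empty, and it is not the Sage-based code used for the actual solve.

```python
def crt_pair(r1, m1, r2, m2):
    """Combine x = r1 (mod m1) and x = r2 (mod m2), assuming gcd(m1, m2) = 1."""
    t = ((r2 - r1) * pow(m1, -1, m2)) % m2
    return r1 + m1 * t, m1 * m2


def crt_tree(remainders, moduli):
    """Merge adjacent pairs level by level until one (remainder, modulus) is left."""
    items = list(zip(remainders, moduli))
    while len(items) > 1:
        merged = []
        # Pair up neighbours: (0, 1), (2, 3), ...
        for i in range(0, len(items) - 1, 2):
            (r1, m1), (r2, m2) = items[i], items[i + 1]
            merged.append(crt_pair(r1, m1, r2, m2))
        # With an odd count, the last element moves up unchanged.
        if len(items) % 2 == 1:
            merged.append(items[-1])
        items = merged
    return items[0]


# Toy data (illustrative only): five small primes and a secret below their product.
moduli = [101, 103, 107, 109, 113]
secret = 1234567890
remainders = [secret % n for n in moduli]

x, modulus = crt_tree(remainders, moduli)
print(x == secret)
```

Each level halves the number of entries while doubling their size, so most of the work happens on small numbers and only the last few merges touch the really big ones.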
For the first k/2 runs, the time to write data to k/2 2b-bit numbers is:
$$\frac{k}{2}\,T_{2b} = k\,T_b$$
The next k/4 runs take:
$$\frac{k}{4}\,T_{4b} = k\,T_b$$
And so on...
Since the number of runs halves each loop, we have $\log_2 k$ loops in total, which gives a total runtime for creating and writing memory of:
$$T = k\,T_b \log_2 k = O(b\,k \log k)$$
Compared to the previous $O(k^2)$ algorithm, we have achieved a speed-up of about (for k = 65537):
$$\frac{k(k+1)/2}{k \log_2 k} = \frac{k+1}{2 \log_2 k} \approx 2048$$
Again, this is only the speed-up ratio for creating and writing new memory. For multiplication and the other more time-consuming operations in crt() the ratio is different, but I guess we could expect a value somewhere around this.
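The ratio can be checked numerically under the same simplified model, counting only memory writes in units of $T_b$ and ignoring multiplication cost. Treat the result as a rough estimate rather than a measured speed-up.

```python
from math import log2

k = 65537

# Sequential CRT: 1 + 2 + ... + k write units.
sequential_units = k * (k + 1) / 2

# Pairwise CRT: about k write units per level, log2(k) levels.
tree_units = k * log2(k)

speedup = sequential_units / tree_units
print(round(speedup))
```

In practice the measured gain depends on how the big-integer library multiplies and on the chunk size, so this number is only a guide.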
With that in mind, let's implement the code (btw for some reason the algorithm runs faster when we consider in each loop doing CRT for every 8 numbers instead of 2):
from sage.all import *
from Crypto.Util.number import *
from c import c
from n import n
c = c[:65537]
n = n[:65537]
"""
/function/ CRT_():
"" Purpose:
Function that calculates CRT on chunks of <SEG_SIZE> numbers in the array
rather than the whole array at once.
Works pretty nice with big array and (probably) big number.
Built on the base of Sage's crt() function.
"" Args:
r: List of remainders.
m: List of modulus.
SEG_SIZE=8: Number of values that CRT should work on once at a time.
debug=False: Print out some debug data.
"""
def CRT_(r, m, SEG_SIZE=8, debug=False):
    if debug:
        print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
        print(f'[ i ] Start loop with len = {len(r)}')
    while len(r) != 1:
        newR = []
        newM = []
        for i in range(0, len(r), SEG_SIZE):
            if len(r) - i == 1:
                # A lone leftover element is carried over unchanged
                newR.append(r[i])
                newM.append(m[i])
            else:
                crt_ = crt(r[i:i+SEG_SIZE], m[i:i+SEG_SIZE])
                prod = 1
                for _m in m[i:i+SEG_SIZE]:
                    prod *= _m
                newR.append(crt_)
                newM.append(prod)
        r = newR
        m = newM
        if debug:
            print(f'[ i ] Update loop with len = {len(r)}')
    if debug:
        print(f'[ i ] Finished :D')
    return r[0]
# Got CRT
# m_65537 = crt(c, n) # <- too slow !
m_65537 = CRT_(c, n, debug=True)
print(long_to_bytes(Integer(m_65537).nth_root(65537)))
And the algorithm only took us 3 minutes to run, which is definitely faster than infinity :)
It's great that my writeup got so much praise from the author, cothan, who is also a long-time Crypto player 😄 I'm crying in happiness right now TwT. Although I came up with the algorithm, I got stuck and gave up half-way through. It was Timmy who stuck to my idea and implemented it to solve the challenge (thank you very^n much, where n goes to infinity) and get juicy points 😗 He also noticed that, for some reason, grouping 8 crt()s together seems to be a better choice than 2(?)
cothan also mentioned that applying parallelism is a very good way to improve the algorithm.
Now, I'm not very familiar with Julia, but parallelism is doable with very little effort. The key here is noticing that within each loop, the results of the crt() calls do not affect each other until the next loop. This means we can have multiple calls to crt() running simultaneously in a loop, which improves our performance a lot. It would have been better coded in Julia instead of Python, but for now let's stick to Python. With a few tweaks to the code, here's what we've got:
from sage.all import *
from Crypto.Util.number import *
import concurrent.futures
from c import c
from n import n
c = c[:65537]
n = n[:65537]
"""
/function/ CRT_():
"" Purpose:
Function that calculates CRT on chunks of <SEG_SIZE> numbers in the array
rather than the whole array at once.
Works pretty nice with big array and (probably) big number.
Test: 65537 1024-bit numbers on a single core -> 3 mins 15 seconds.
Test: 65537 1024-bit numbers on 8 cores -> 1 mins 2 seconds.
Built on the base of Sage's crt() function.
"" Args:
r: List of remainders.
m: List of modulus.
SEG_SIZE=12: Number of values that CRT should work on once at a time.
debug=False: Print out some debug data.
PROCESS_NO=8: Number of processes to run at the same time.
"""
def CRT_(r, m, SEG_SIZE=12, debug=False, PROCESS_NO=8):
    assert len(r) == len(m)
    if debug:
        print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
        print(f'[ i ] Start loop with len = {len(r)}')
    with concurrent.futures.ProcessPoolExecutor(PROCESS_NO) as executor:
        while len(r) != 1:
            newR = []
            newM = []
            for i in range(0, len(r), SEG_SIZE):
                if len(r) - i == 1:
                    # A lone leftover element is carried over unchanged
                    newR.append(Integer(r[i]))
                    newM.append(Integer(m[i]))
                else:
                    # Each chunk is independent, so submit it to the pool
                    newR.append(executor.submit(crt, r[i:i+SEG_SIZE], m[i:i+SEG_SIZE]))
                    newM.append(executor.submit(prod, m[i:i+SEG_SIZE]))
            # Obtain processes' results :3
            for i in range(len(newR)):
                if not isinstance(newR[i], Integer):
                    newR[i] = newR[i].result()
                    newM[i] = newM[i].result()
            r = newR
            m = newM
            if debug:
                print(f'[ i ] Update loop with len = {len(r)}')
    if debug:
        print(f'[ i ] Finished :D')
    return r[0]
# Got CRT
# m_65537 = crt(c, n) # <- too slow !
m_65537 = CRT_(c, n, debug=True, SEG_SIZE=8)
print(long_to_bytes(Integer(m_65537).nth_root(65537)))
And this is the result on my 8-core machine 👯👯👯👯 wow!! 3 times faster.
... And we can do even better than this! You see, Sage's crt() only returns the new remainder. It doesn't return the new modulus, even though it calculates that modulus along the way! This forces us to recalculate it later just to append it to our new moduli array, which is redundant to be honest.
### Snippet from CRT_list()
x = v[0]
m = moduli[0]
from sage.arith.functions import lcm
for i in range(1, len(v)):
    x = CRT(x,v[i],m,moduli[i])
    m = lcm(m,moduli[i])
return x % m # Should also return m
Sage's code base is already nice, so I just pulled the functions crt() and CRT_list() out of the source and modified them a little so that they return the new modulus as well:
from sage.all import *
from sage.structure.coerce import py_scalar_to_element
from sage.arith.functions import lcm
from Crypto.Util.number import *
import concurrent.futures
"""
/function/ crt_2():
"" Purpose:
Modified from Sage's crt() function.
Returns CRT value with moduli instead of just CRT.
"" Args:
a, b: Remainder 1 and 2
m, n: Modulus 1 and 2
"""
def crt_2(a, b, m=None, n=None):
    try:
        f = (b-a).quo_rem
    except (TypeError, AttributeError):
        # Maybe there is no coercion between a and b.
        # Maybe (b-a) does not have a quo_rem attribute
        a = py_scalar_to_element(a)
        b = py_scalar_to_element(b)
        f = (b-a).quo_rem
    g, alpha, beta = xgcd(m, n)
    q, r = f(g)
    if r != 0:
        raise ValueError("No solution to crt problem since gcd(%s,%s) does not divide %s-%s" % (m, n, a, b))
    x = a + q*alpha*py_scalar_to_element(m)
    l = lcm(m, n)
    # Return the combined modulus alongside the remainder
    return x % l, l
"""
/function/ crt_():
"" Purpose:
Modified from Sage's CRT_list() function.
Returns CRT value with moduli instead of just CRT.
"" Args:
r: List of remainders.
m: List of modulus.
"""
def crt_(r, m):
    res = r[0]
    prod = m[0]
    for i in range(1, len(r)):
        res, prod = crt_2(res, r[i], prod, m[i])
    return res % prod, prod
"""
/function/ CRT_():
"" Purpose:
Function that calculates CRT on chunks of <SEG_SIZE> numbers in the array
rather than the whole array at once. The goal is to reduce long multiplication
time.
Works pretty nice with big array and (probably) big number.
Test: 65537 1024-bit numbers on a single core -> 3 mins 15 seconds.
Test: 65537 1024-bit numbers on 8 cores -> 1 mins 2 seconds.
Built on the base of Sage's crt() function.
"" Args:
r: List of remainders.
m: List of modulus.
SEG_SIZE=12: Number of values that CRT should work on once at a time.
NO_CORES=8: Number of cores in your machine.
debug=False: Print out some debug data.
"""
def CRT_(r: list, m: list, SEG_SIZE=12, NO_CORES=8, debug=False):
    assert len(r) == len(m) >= 2
    assert SEG_SIZE > 1
    if debug:
        print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
        print(f'[ i ] Start loop with len = {len(r)}')
    with concurrent.futures.ProcessPoolExecutor(NO_CORES) as executor:
        while len(r) != 1:
            newR = []
            newM = []
            futures = []
            for i in range(0, len(r), SEG_SIZE):
                if len(r) - i == 1:
                    # A lone leftover element is carried over unchanged
                    newR.append(r[i])
                    newM.append(m[i])
                else:
                    futures.append(executor.submit(crt_, r[i:i+SEG_SIZE], m[i:i+SEG_SIZE]))
            # Obtain processes' results :3
            for future in futures:
                result = future.result()
                newR.append(result[0])
                newM.append(result[1])
            r = newR
            m = newM
            if debug:
                print(f'[ i ] Update loop with len = {len(r)}')
    if debug:
        print(f'[ i ] Finished :D')
    return r[0]
# Got CRT
from c import c
from n import n
c = c[:65537]
n = n[:65537]
# m_65537 = crt(c, n) # <- too slow !
m_65537 = CRT_(c, n, debug=True, SEG_SIZE=8, NO_CORES=8)
print(long_to_bytes(Integer(m_65537).nth_root(65537)))
With this final piece of code, we have shaved off about 20 more seconds of runtime. This is actually pretty crazy, considering that the intended solution for this challenge ran for about 3 hours using optimised CRT() code written in Julia instead of Python (confirmed by cothan himself)!! Now a low-end user can run this algorithm within minutes, seconds even! On a machine with many more cores, it can perform this task like a breeze!